Site cover image

Site icon imageSen(Qian)’s Memo

This website is Donglin Qian (Torin Sen)’s memo, especially about machine learning papers and competitive programming.

2024-KDD-[PTLoss]Knowledge Distillation with Perturbed Loss: From a Vanilla Teacher to a Proxy Teacher

https://openreview.net/forum?id=p14iRzavpt

Introduction

大きなモデルの各サンプルについて、できるだけ推論能力を保ちつつもモデルを小さくする手法として、Knowledge Distillationがある。通常のKDは訓練したモデルとできるだけ同じ分布の出力をするのを目的としているが、

  • 教師モデルの学習データのバイアス
  • 学習するときに教師モデルの誤った学習

などが原因でうまくない予測をしてしまうことがある。

これに、温度というパラメタを導入して、TTが1に近いならばlogitがそのまま使われるが、TTを大きくすると分布がなだらかになり、クラス間の確率差が小さくする、という手法がある。

pscaled=exp(z/T)iexp(zi/T)p_\text{scaled} = \frac{\exp(z / T)}{\sum_i \exp(z_i / T)}

しかし、このTTを調節することで得られる効果に限りがあり(なだらかにするだけでは問題を十分に解決できない)し、グリッドサーチでハイパラ調整しないといけない。

この論文では、通常使われているKL-divergenceにマクローリン展開を施し、1次項を摂動として加えるようなKL-Divergenceの改良によって改善を試みる。これは数学的にも正しいとわかっている。

事前知識

一般的な知識蒸留のやり方

Ground Truthラベルがない訓練データセットDDと、訓練済みの教師モデルptp^tが与えられる。これをもとに、学生モデルpsp^sを訓練する。

この時、以下のようにKLダイバージェンス(経験的に計算)を最小化するのが目的である。

Image in a image block

提案手法

教師モデルの出力はground truthのラベルの分布からずれているので、知識蒸留ではKL Divergenceを単に使うだけではうまくいかない。

これを緩めるために、以下のようにKL Divergenceで計算されるlog\logを、以下のようにマクローリン展開して、各次元の係数1/m1/mに摂動ϵm\epsilon_mを加えることで、バイアスがかかってもうまくとらえられるのではないか?

log(x)=m=1(1x)mmlog(x)m=1(1m+ϵm)(1x)m\log (x) = - \sum_{m =1}^\infty \frac{(1 - x)^m}{m} \to \log (x) \approx - \sum_{m =1}^\infty (\frac{1}{m} + \epsilon_m)(1 - x)^m

この摂動を加えた対数関数をKL Divergenceの中の「交差エントロピー」で新たに使わせた、新しいKL DivergenceをLossとして、知識蒸留を行う。(エントロピーの部分は据え置く)

KL Divergenceは負のエントロピーと交差エントロピーの差なので、経験的に書き直すと以下のようなる。

ここで、pct,pcsp_c^t, p_c^sは教師モデル、生徒モデルのそれぞれクラスccに対しての予測確率。

KL(pt(xn)ps(xn))=H(pt(xn))cCpct(xn)logpcs(xn)H(pt(xn))cCpct(xn){logpcs(xn)m=1ϵc,m(1pcs(xn))m}=KL(pt(xn)ps(xn))+cCpct(xn)m=1ϵc,m(1pcs(xn))m=lPT(pt(xn)ps(xn))KL(p^t(x_n) || p^s(x_n)) = -H(p^t(x_n)) - \sum_{c \in C}p_c^t(x_n) \log p_c^s(x_n) \\ \approx -H(p^t(x_n)) - \sum_{c \in C}p_c^t(x_n) \{ \log p_c^s(x_n) - \sum_{m=1}^\infty \epsilon_{c, m} (1 - p_c^s (x_n))^m\} \\ = KL(p^t(x_n) || p^s(x_n)) + \sum_{c \in C}p_c^t(x_n) \sum_{m=1}^\infty \epsilon_{c, m} (1 - p_c^s (x_n))^m \\ = l_{PT}(p^t(x_n) || p^s(x_n))

このように得たlPTl_{PT}がこの論文の提案したものである。

実際では、無限次数まで級数を計算せず、MMを設定してそこまで打ち切らせる。

提案手法は、以下のように教師のモデルの出力をKL Divergenceで評価するよりもさらに平滑化させることができるとわかる。

Image in a image block

そのうえ、以下のように摂動のϵ\epsilonを変更することで損失曲線を柔軟に操縦できるとわかる。

Image in a image block

提案手法の理論的根拠

教師モデルがリスクの上限に与える影響

Knowledge Distillationについての一般的な定理

  • ppが学生モデル。pp^*が真の分布。
  • 学生モデルppと教師モデルptp^tとUの分布DUD_Uについて、それの誤差の二乗は右辺で抑えられる。
  • 右辺の第一項は、Uデータの分布DUD_U上での教師モデルの出力pt(x)p^t(x)と真のラベル分布のp(x)p(x)のCross Entropyの分散である。
  • 右辺の第二項は、教師モデルの出力分布と真の分布の2乗損失と、教師モデルのエントロピー
Image in a image block

この結果からは、NUN_Uの数を大きくすると上界は小さくなるが、すべてを支配してるわけではない。特に真のモデルと教師モデルの分布の差は埋まりようがない。他には教師モデル(=真の分布)のエントロピーが大きければそれだけで上界が大きくなる

KL Divergence下での提案手法の妥当性

上の定理で、NuN_uを大きくすると、二項目のみ残るので、二項目を小さくしたい。

PT Lossを用いて計算することで、KL DivergenceがPT Lossとなるような教師分布ppxp^{px}という代理教師分布の存在を仮定する。(あるかもわからない)

あるかもわからないので、現実では以下のように最適化問題を解いていく。PT Lossで計算された値R~PTM\tilde{R}_{PT-M}と、代理教師の出力分布ppxp^{px}と生徒のKL Divergenceができるだけ一致させるような、ppxp^{px}を求める

下の数式で、RPTM(ps,pt)R_{PT-M}(p^s, p^t)とあるが、これは逆でRPTM(pt,ps)R_{PT-M}(p^t, p^s)という配置のほうが正しい

Image in a image block

これはこのままでは解けないので、一つだけ妥協をする。ppxp^{px}を動かして最小化するというが、これが難しいので、まず「代理教師も学生は近くなるだろう」という仮定のもと、ps=ppxp^{s} = p^{px}として代入をする。

RPTM(pt,ppx)=KL(pt,ppx)+cpct(xn)mϵc,m(1pcpx)RKL(ppx,ppx)=KL(ppx,ppx)=0R_{PT-M}(p^{t}, p^{px}) = KL(p^{t}, p^{px}) + \sum_c p_c^{t}(x_n) \sum_m \epsilon_{c, m} (1 - p_c^{px})\\ R_{KL}(p^{px}, p^{px}) = KL(p^{px}, p^{px}) = 0

Image in a image block

実際に最小化でppxp^{px}を得るのは、すべての教師モデルの出力のlogitがあるので、それに対してこれを解くように数値計算で解く

ϵ\epsilonの選び方

ϵ\epsilonは結論から言うと、ランダムに選ぶ。ランダムに選んだうえで、先ほどの定理のNuN_uに依存しない項に対して計算し、最も小さい値をとるものとする。

Image in a image block

Experiments

  • log\logM=5M=5まで展開しているらしい。